CREATE OR REPLACE FUNCTION "pipeline"."kmeans_pp"("source_table" varchar, "out_table_model" varchar, "out_table_result" varchar, "id_col" varchar, "feature_cols" varchar, "k" int4, "fn_dist" varchar, "agg_centroid" varchar, "max_iter" int4, "min_frac" float8)
  RETURNS "pg_catalog"."text" AS $BODY$
import json
import time
import random


def columnsToVecTable(sourceTable, featureCols, otherCols, outTable):
    """Pack feature columns of ``sourceTable`` into a vector table via MADlib.

    Builds and runs a ``madlib.cols2vec`` call that creates ``outTable``
    from the listed feature columns, keeping ``otherCols`` alongside.

    Args:
        sourceTable: Name of the input table (interpolated verbatim into SQL).
        featureCols: Comma-separated feature column names.
        otherCols: Comma-separated names of columns to keep in the output
            table in addition to the vector; may be empty or None.
        outTable: Name of the table that cols2vec should create.

    Returns:
        A ``(ok, msg)`` tuple. ``ok`` is True when the call succeeded; on
        failure ``msg`` carries the SQL text and the error message.
    """
    # NOTE(review): table and column names are spliced into the SQL text
    # unescaped; callers must only pass trusted identifiers.

    def quoteNames(csv):
        # Double-quote every non-empty name so case and special characters
        # survive; empty items would otherwise become the invalid identifier "".
        return ['"{}"'.format(name) for name in (csv or "").split(",") if name]

    featureList = quoteNames(featureCols)
    keepList = quoteNames(otherCols)
    # With nothing to keep, pass SQL NULL rather than an empty identifier list.
    keepArg = "'{}'".format(",".join(keepList)) if keepList else "NULL"
    sqlStr = "SELECT madlib.cols2vec('{}', '{}', '{}', NULL, {})".format(
        sourceTable, outTable, ",".join(featureList), keepArg)

    try:
        plpy.execute(sqlStr)
    except Exception as e:
        # Remove whatever the failed call may have left behind. The summary
        # table name must be derived from the outTable *variable*.
        plpy.execute("DROP TABLE IF EXISTS {}".format(outTable))
        plpy.execute("DROP TABLE IF EXISTS {}_summary".format(outTable))
        return False, "sql=[{}] errorMsg=[{}]".format(sqlStr, str(e))
    return True, "success" + sqlStr


def logError(msg, status, result):
    """Stamp ``result`` with an error status/message and serialise it.

    The dict passed as ``result`` is modified in place: its ``"status"`` and
    ``"error_msg"`` entries are overwritten (or created) before the whole
    dict is returned as a JSON string.
    """
    result["status"], result["error_msg"] = status, msg
    return json.dumps(result)


def getTableMetaInfo(table_name):
    """Return a ``{column_name: data_type}`` dict for ``table_name``.

    The lookup goes through ``information_schema.columns``. ``table_name``
    may be bare (``tbl``) or schema-qualified (``schema.tbl``); when a schema
    is given the lookup is restricted to it so that same-named tables in
    other schemas do not leak columns into the result.

    Args:
        table_name: Table name, optionally prefixed with a schema.

    Returns:
        Dict mapping each column name to its data type; empty when no
        matching table is found.
    """

    def catalogName(part):
        # Convert an identifier as written in SQL into the form stored in
        # the catalog: quoted names keep their case, unquoted names fold
        # to lower case.
        part = part.strip()
        if len(part) >= 2 and part[0] == '"' and part[-1] == '"':
            return part[1:-1].replace('""', '"')
        return part.lower()

    parts = [catalogName(piece) for piece in table_name.split(".")]

    # Values are bound as parameters instead of being spliced into the SQL.
    sql_str = "select column_name, data_type from information_schema.columns where table_name = $1"
    arg_types = ["text"]
    args = [parts[-1]]
    if len(parts) > 1:
        # The component right before the table name is the schema.
        sql_str += " and table_schema = $2"
        arg_types.append("text")
        args.append(parts[-2])

    plan = plpy.prepare(sql_str, arg_types)
    rs = plpy.execute(plan, args)
    return {line['column_name']: line['data_type'] for line in rs}


def begin_alg(source_table, out_table_model, out_table_result, id_col, feature_cols, k, fn_dist, agg_centroid, max_iter,
              min_frac):
    """Run MADlib k-means++ on ``source_table`` and store model and assignments.

    Steps:
      1. Pack ``feature_cols`` into a temporary vector table (cols2vec).
      2. Create ``out_table_model`` from ``madlib.kmeanspp`` on that table.
      3. Create ``out_table_result`` with the features, the remaining source
         columns (except ``id_col``), a ``cluster_id`` and a ``_record_id_``.
      4. Drop the temporary vector table.

    Args:
        source_table: Schema-qualified input table name.
        out_table_model: Name of the model table to create.
        out_table_result: Name of the per-row result table to create.
        id_col: Row identifier column; used for ordering the result and not
            copied to the result table.
        feature_cols: Comma-separated feature column names.
        k, fn_dist, agg_centroid, max_iter, min_frac: Passed through to
            ``madlib.kmeanspp``.

    Returns:
        A JSON string with ``status`` (0 on success, 500 on failure),
        ``error_msg`` and ``result`` (the input parameters echoed back plus
        the created tables, result table first, with their column names).
    """
    result = {
        "status": 0,
        "error_msg": "success",
        "result": {
            "input_params": {
                "source_table": source_table,
                "out_table_model": out_table_model,
                "out_table_result": out_table_result,
                "id_col": id_col,
                "feature_cols": feature_cols,
                "k": k,
                "fn_dist": fn_dist,
                "agg_centroid": agg_centroid,
                "max_iter": max_iter,
                "min_frac": min_frac
            },
            "output_params": [
            ]
        }
    }
    outputParams = result["result"]["output_params"]

    def dropTables(*names):
        # Remove each named table if it exists.
        for name in names:
            plpy.execute("DROP TABLE IF EXISTS {}".format(name))

    def describeTable(name):
        # list(...) keeps the column names JSON-serialisable even where
        # dict.keys() returns a view rather than a list.
        return {
            "out_table_name": name,
            "output_cols": list(getTableMetaInfo(name).keys())
        }

    # Temporary vector table: random suffix plus timestamp to avoid clashes.
    randName = random.sample('zyxwvutsrqponmlkjihgfedcba', 7)
    matTable = "pipeline.tmp_vector_{}_{}".format("".join(randName), int(time.time()))
    matSummary = "{}_summary".format(matTable)

    sourceKeys = set(getTableMetaInfo(source_table).keys())
    if id_col not in sourceKeys:
        # Report through the JSON channel instead of raising from list.remove.
        return logError('id_col "{}" not found in source_table {}'.format(id_col, source_table), 500, result)

    features = feature_cols.split(",")
    # Pass-through columns: everything that is neither a feature nor the id.
    # Sorted so the result table's column order is deterministic.
    otherCols = sorted(sourceKeys - set(features) - set([id_col]))

    # The id column must be present in the vector table because the result
    # query orders by it; it is still left out of the result's select list.
    flag, msg = columnsToVecTable(source_table, feature_cols, ",".join(otherCols + [id_col]), matTable)
    if not flag:
        dropTables(matSummary, matTable)
        return logError(msg, 500, result)

    model_table = out_table_model
    modelSummary = model_table + "_summary"
    model_sql = """
        CREATE TABLE {} as SELECT * FROM madlib.kmeanspp('{}', '{}', {}, '{}', '{}', {}, {})
    """.format(model_table, matTable, 'feature_vector', k, fn_dist, agg_centroid, max_iter, min_frac)
    try:
        plpy.execute(model_sql)
        outputParams.append(describeTable(model_table))
        dropTables(modelSummary)
    except Exception as e:
        # NOTE(review): this also drops model_table when the CREATE failed
        # because the table already existed - confirm that is intended.
        dropTables(model_table, modelSummary, matTable, matSummary)
        return logError(str(e) + " sql=" + model_sql, 500, result)

    # Unpack the vector back into named feature columns, then pass-through columns.
    selectBody = ['data.feature_vector[{}] as "{}"'.format(pos, name) for pos, name in enumerate(features, 1)]
    # Quote pass-through columns the same way the feature aliases are quoted.
    selectBody.extend('data."{0}" as "{0}"'.format(col) for col in otherCols)

    resultTable = out_table_result
    resultSql = "CREATE TABLE {} AS SELECT {}, concat((madlib.closest_column(centroids, {}, '{}')).column_id) as cluster_id, row_number() over() as _record_id_ from {} as data, {} ORDER BY data.\"{}\"".format(
        resultTable, ",".join(selectBody), 'feature_vector', fn_dist, matTable, model_table, id_col)

    try:
        plpy.execute(resultSql)
        # The result table is reported first, ahead of the model table.
        outputParams.insert(0, describeTable(resultTable))
    except Exception as e:
        # NOTE(review): as above, a pre-existing result table is dropped too.
        dropTables(model_table, resultTable)
        return logError(str(e) + " sql=" + resultSql, 500, result)
    finally:
        # The vector table is temporary on every path, success or failure.
        dropTables(matTable, matSummary)

    return json.dumps(result)


if __name__ == "__main__":
    ret = begin_alg(source_table, out_table_model, out_table_result, id_col, feature_cols, k, fn_dist, agg_centroid,
                    max_iter, min_frac)
    return ret

$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100